import tensorflow as tf


class Model(tf.keras.Model):
    '''
    Point-based network that predicts a 3-vector for every "fluid" particle.

    For each selected (fluid) particle, the features of up to `num_max_neighbor` neighbors
    inside a query ball are gathered, concatenated with the particle's own feature, passed
    through a stack of tanh Dense layers plus a final linear Dense(3), max-pooled over the
    neighbor axis, and offset by a constant [0, -9.8, 0].
    '''

    def __init__(self, mlp_units, query_ball_radius, num_max_neighbor=64, trainable=True, *args, **kwargs):
        '''
        :param mlp_units: Iterable of Int. Width of each hidden tanh Dense layer, in order.
        :param query_ball_radius: Float. Radius of the neighbor query ball.
        :param num_max_neighbor: Int. Number of neighbor slots per selected particle.
        :param trainable: Bool. Forwarded to every Dense layer created here.
        '''
        super().__init__(*args, **kwargs)
        self.query_radius = query_ball_radius
        self.num_max_neighbor = num_max_neighbor
        self.mlp_list = []
        for mlp_cell in mlp_units:
            self.mlp_list.append(tf.keras.layers.Dense(mlp_cell, activation=tf.keras.activations.tanh,
                                                       trainable=trainable))
        self.last_fcn = tf.keras.layers.Dense(3, activation=None, trainable=trainable)

    def _query_ball(self, query_radius, num_max_neighbor, select_mask, all_pos, all_feature):
        '''
        build the feature of neighbor particles within a radius of select particles in all particles
        :param query_radius: Float. the radius of query ball
        :param num_max_neighbor: Int. the max num of neighbor particles of a particle
        :param select_mask: A bool array, [N]. The mask of select particles. So the select particles are
        'all[select_mask]', which is [F, 3]
        :param all_pos: Float array, [N, 3]. The coordinate of all particles
        :param all_feature: Float array, [N, 4]. The feature of all particles
        :return: Float array, [F, max, 3+4]. The neighbor relative_pos+feature of each select particle in all
        particles. The relative_pos means {the pos of neighbor particles - the pos} of a select particle.
        ATTENTION, a particle cannot be the neighbor of itself. A particle is a neighbor when its distance
        is <= query_radius. If the num of actual neighbors is out of max, then choose the nearest particles
        (nearest first). If the num of actual neighbors is less than 'max', then fill the relative_pos
        domain [0, 5*r, 0], and fill the feature domain [0,0,0,0].
        NOTE: builds a dense [F, N] distance matrix, so memory grows as O(F * N).
        '''
        select_pos = tf.boolean_mask(all_pos, select_mask)  # [F, 3]
        # Index of each selected particle in the full set, used to exclude self-neighbors.
        select_index = tf.cast(tf.where(select_mask)[:, 0], tf.int32)  # [F]
        num_all = tf.shape(all_pos)[0]

        # Squared distances avoid sqrt (and its undefined gradient at zero distance).
        rel_all = all_pos[tf.newaxis, :, :] - select_pos[:, tf.newaxis, :]  # [F, N, 3]
        dist_sq = tf.reduce_sum(tf.square(rel_all), axis=-1)  # [F, N]
        radius_sq = tf.square(tf.cast(query_radius, dist_sq.dtype))
        is_self = tf.equal(select_index[:, tf.newaxis], tf.range(num_all)[tf.newaxis, :])  # [F, N]
        valid = tf.logical_and(dist_sq <= radius_sq, tf.logical_not(is_self))

        inf = tf.constant(float("inf"), dtype=dist_sq.dtype)
        masked_dist_sq = tf.where(valid, dist_sq, inf)  # [F, N]
        # Append num_max_neighbor always-invalid columns so top_k works even when N < num_max_neighbor.
        padded_dist_sq = tf.pad(masked_dist_sq, [[0, 0], [0, num_max_neighbor]], constant_values=inf)
        # top_k of the negated distance == the nearest candidates, nearest first.
        neg_dist_sq, nn_index = tf.math.top_k(-padded_dist_sq, k=num_max_neighbor)  # [F, m]
        is_valid = tf.math.is_finite(neg_dist_sq)[..., tf.newaxis]  # [F, m, 1]
        # Padding columns have indices >= N; clamp them so gather stays in range (values are replaced below).
        safe_index = tf.minimum(nn_index, tf.maximum(num_all - 1, 0))

        relative_pos = tf.gather(all_pos, safe_index) - select_pos[:, tf.newaxis, :]  # [F, m, 3]
        pad_rel = tf.constant([0.0, 5.0, 0.0], dtype=all_pos.dtype) * tf.cast(query_radius, all_pos.dtype)
        relative_pos = tf.where(is_valid, relative_pos, pad_rel)
        neighbor_feat = tf.gather(all_feature, safe_index)  # [F, m, 4]
        neighbor_feat = tf.where(is_valid, neighbor_feat, tf.zeros_like(neighbor_feat))
        return tf.concat([relative_pos, neighbor_feat], axis=-1)  # [F, m, 7]

    def pred(self, inputs):  # [N, 7]
        '''
        Run the network on all particles and return predictions for the selected (fluid) ones.

        :param inputs: Float tensor, [N, 7]. Columns 0:3 are positions, columns 3:7 are features.
        :return: (pred, fluid_mask). pred is [F, 3]; fluid_mask is the bool [N] mask (column 6 > 0).
        '''
        pos = inputs[:, :3]  # [N, 3]
        feature = inputs[:, 3:]  # [N, 4]
        # NOTE(review): column 6 appears to act as a fluid flag (> 0 means fluid) — confirm with the data.
        fluid_mask = inputs[:, 6] > 0  # [N]
        neighbor_feature = self._query_ball(self.query_radius, self.num_max_neighbor,
                                            fluid_mask, pos, feature)  # [F, m, 7]
        # tf.Tensor does not support NumPy-style boolean indexing; use boolean_mask.
        fluid_feature = tf.boolean_mask(feature, fluid_mask)[:, :3]  # [F, 3]
        tiled_fluid_feature = tf.tile(fluid_feature[:, tf.newaxis, :],
                                      [1, self.num_max_neighbor, 1])  # [F, m, 3]
        # Concatenate along the channel axis: [F, m, 7] + [F, m, 3] -> [F, m, 10].
        feature_mlp = tf.concat([neighbor_feature, tiled_fluid_feature], axis=-1)
        for mlp in self.mlp_list:
            feature_mlp = mlp(feature_mlp)
        out = self.last_fcn(feature_mlp)  # [F, m, 3]
        # Max-pool over neighbors, then add a constant offset (presumably gravity along -y).
        gravity_offset = tf.constant([[0.0, -9.8, 0.0]], dtype=out.dtype)
        out = tf.reduce_max(out, axis=1) + gravity_offset  # [F, 3]
        return out, fluid_mask

    def loss(self, pred, truth, fluid_mask):
        '''
        Mean squared error between predictions and the fluid rows of the ground truth.

        :param pred: Float tensor, [F, 3].
        :param truth: Float tensor, [N, 3] (rows for all particles; only fluid rows are compared).
        :param fluid_mask: Bool tensor, [N].
        NOTE(review): tf.keras.Model.compile() may assign `self.loss`, which would shadow this
        method — avoid calling compile() on this model or confirm the Keras version's behavior.
        '''
        return tf.reduce_mean(tf.square(pred - tf.boolean_mask(truth, fluid_mask)))

    def train(self, inputs, outputs, optimizer):
        '''
        Perform one gradient step on (inputs, outputs) and return the loss before the update.
        '''
        with tf.GradientTape() as tape:
            pred, fluid_mask = self.pred(inputs)
            current_loss = self.loss(pred, outputs, fluid_mask)
        # Compute and apply gradients outside the tape so these ops are not recorded.
        grads = tape.gradient(current_loss, self.trainable_variables)
        optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return current_loss
